{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "cellView": "form",
        "id": "C1rAsD2L-hSO"
      },
      "outputs": [],
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b6f8f3af-744e-4eaa-8a30-6d03e8e4d21e"
      },
      "source": [
        "# Apache Beam RunInference for PyTorch\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_pytorch.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_pytorch.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A8xNRyZMW1yK"
      },
      "source": [
        "This notebook demonstrates the use of the RunInference transform for PyTorch. Apache Beam includes implementations of the [ModelHandler](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.ModelHandler) class for [users of PyTorch](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.pytorch_inference.html). For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation.\n",
        "\n",
        "\n",
        "This notebook illustrates common RunInference patterns, such as:\n",
        "*   Using a database with RunInference.\n",
        "*   Postprocessing results after using RunInference.\n",
        "*   Inference with multiple models in the same pipeline.\n",
        "\n",
        "\n",
        "The linear regression models used in these samples are trained on data that correspondes to the 5 and 10 times tables; that is, `y = 5x` and `y = 10x` respectively."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "299af9bb-b2fc-405c-96e7-ee0a6ae24bdd"
      },
      "source": [
        "## Dependencies\n",
        "\n",
        "The RunInference library is available in Apache Beam versions 2.40 and later.\n",
        "\n",
        "To use Pytorch RunInference API, you need to install the PyTorch module. To install PyTorch, use `pip`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "loxD-rOVchRn"
      },
      "outputs": [],
      "source": [
        "!pip install apache_beam[interactive,gcp,dataframe] --quiet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 39,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7f841596-f217-46d2-b64e-1952db4de4cb",
        "outputId": "09e0026a-cf8e-455c-9580-bfaef44683ce"
      },
      "outputs": [],
      "source": [
        "%pip install torch --quiet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 40,
      "metadata": {
        "id": "9a92e3a7-beb6-46ae-a5b0-53c15487de38"
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import csv\n",
        "import json\n",
        "import os\n",
        "import torch\n",
        "from typing import Tuple\n",
        "\n",
        "import apache_beam as beam\n",
        "import numpy\n",
        "from apache_beam.io.gcp.bigquery import ReadFromBigQuery\n",
        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
        "from apache_beam.ml.inference.base import PredictionResult\n",
        "from apache_beam.ml.inference.base import RunInference\n",
        "from apache_beam.dataframe.convert import to_pcollection\n",
        "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n",
        "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor\n",
        "from apache_beam.options.pipeline_options import PipelineOptions\n",
        "\n",
        "import warnings\n",
        "warnings.filterwarnings('ignore')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 41,
      "metadata": {
        "id": "V0E35R5Ka2cE"
      },
      "outputs": [],
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 42,
      "metadata": {
        "id": "248458a6-cfd8-474d-ad0e-f37f7ae981ae"
      },
      "outputs": [],
      "source": [
        "# Constants\n",
        "project = \"<your GCP project>\" # @param {type:'string'}\n",
        "bucket = \"<your GCP bucket>\" # @param {type:'string'}\n",
        "\n",
        "# To avoid warnings, set the project.\n",
        "os.environ['GOOGLE_CLOUD_PROJECT'] = project\n",
        "\n",
        "save_model_dir_multiply_five = 'five_times_table_torch.pt'\n",
        "save_model_dir_multiply_ten = 'ten_times_table_torch.pt'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b2b7cedc-79f5-4599-8178-e5da35dba032",
        "tags": []
      },
      "source": [
        "## Create data and PyTorch models for the RunInference transform\n",
        "Create linear regression models, prepare train and test data, and train models."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "202e5a3e-4ccd-4ae3-9852-e47de0721839"
      },
      "source": [
        "### Create a linear regression model in PyTorch\n",
        "Use the following code to create a linear regression model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 43,
      "metadata": {
        "id": "68bf8bf0-f735-45ee-8ebb-a2d8bb8a6bc7"
      },
      "outputs": [],
      "source": [
        "class LinearRegression(torch.nn.Module):\n",
        "    def __init__(self, input_dim=1, output_dim=1):\n",
        "        super().__init__()\n",
        "        self.linear = torch.nn.Linear(input_dim, output_dim)  \n",
        "    def forward(self, x):\n",
        "        out = self.linear(x)\n",
        "        return out"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1918435c-0029-4eb6-8eee-bda5470eb2ff"
      },
      "source": [
        "### Prepare train and test data for an example model\n",
        "This example model is a 5 times table.\n",
        "\n",
        "* `x` contains values in the range from 0 to 99.\n",
        "* `y` is a list of 5 * `x`. \n",
        "* `value_to_predict` includes values outside of the training data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {
        "id": "9302917f-6200-4af4-a410-4bd6f2a070b8"
      },
      "outputs": [],
      "source": [
        "x = numpy.arange(0, 100, dtype=numpy.float32).reshape(-1, 1)\n",
        "y = (x * 5).reshape(-1, 1)\n",
        "value_to_predict = numpy.array([20, 40, 60, 90], dtype=numpy.float32).reshape(-1, 1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9dc22aec-08c3-43ab-a5ce-451cb63c485a"
      },
      "source": [
        "### Train the linear regression mode on 5 times data\n",
        "Use the following code to train your linear regression model on the 5 times table."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {
        "id": "0a8b7924-ff06-4584-8f41-079268387a67"
      },
      "outputs": [],
      "source": [
        "five_times_model = LinearRegression()\n",
        "optimizer = torch.optim.Adam(five_times_model.parameters())\n",
        "loss_fn = torch.nn.L1Loss()\n",
        "\n",
        "\"\"\"\n",
        "Train the five_times_model\n",
        "\"\"\"\n",
        "epochs = 10000\n",
        "tensor_x = torch.from_numpy(x)\n",
        "tensor_y = torch.from_numpy(y)\n",
        "for epoch in range(epochs):\n",
        "    y_pred = five_times_model(tensor_x)\n",
        "    loss = loss_fn(y_pred, tensor_y)\n",
        "    five_times_model.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizer.step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bd106b29-6187-42c1-9743-1666c147b5e3"
      },
      "source": [
        "Save the model using `torch.save()`, and then confirm that the saved model file exists."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "882bbada-4f6d-4370-a047-c5961e564ee8",
        "outputId": "ab7242a9-76eb-4760-d74e-c725261e2a34"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "True\n"
          ]
        }
      ],
      "source": [
        "torch.save(five_times_model.state_dict(), save_model_dir_multiply_five)\n",
        "print(os.path.exists(save_model_dir_multiply_five)) # Verify that the model is saved."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fa84cfca-83c6-4a91-aea1-3dd034c42ae0"
      },
      "source": [
        "### Prepare train and test data for a 10 times model\n",
        "This example model is a 10 times table.\n",
        "* `x` contains values in the range from 0 to 99.\n",
        "* `y` is a list of 10 * `x`. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 47,
      "metadata": {
        "id": "27e0d1f6-2c3e-4418-8fb0-b5b89ffa66d5"
      },
      "outputs": [],
      "source": [
        "x = numpy.arange(0, 100, dtype=numpy.float32).reshape(-1, 1)\n",
        "y = (x * 10).reshape(-1, 1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "24d946dc-4fe0-4030-8f6a-aa8d27fd353d"
      },
      "source": [
        "### Train the linear regression model on 10 times data\n",
        "Use the following to train your linear regression model on the 10 times table."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "metadata": {
        "id": "2b352313-b791-48fd-9b9d-b54233176416"
      },
      "outputs": [],
      "source": [
        "ten_times_model = LinearRegression()\n",
        "optimizer = torch.optim.Adam(ten_times_model.parameters())\n",
        "loss_fn = torch.nn.L1Loss()\n",
        "\n",
        "epochs = 10000\n",
        "tensor_x = torch.from_numpy(x)\n",
        "tensor_y = torch.from_numpy(y)\n",
        "for epoch in range(epochs):\n",
        "    y_pred = ten_times_model(tensor_x)\n",
        "    loss = loss_fn(y_pred, tensor_y)\n",
        "    ten_times_model.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizer.step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6f959e3b-230b-45e2-9df3-dd1f11acacd7"
      },
      "source": [
        "Save the model using `torch.save()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 49,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "42b2ca0f-5d44-4d15-a313-f3d56ae7f675",
        "outputId": "9cb2f268-a500-4ad5-a075-856c87b8e3be"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "True\n"
          ]
        }
      ],
      "source": [
        "torch.save(ten_times_model.state_dict(), save_model_dir_multiply_ten)\n",
        "print(os.path.exists(save_model_dir_multiply_ten)) # verify if the model is saved"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2e20efc4-13e8-46e2-9848-c0347deaa5af"
      },
      "source": [
        "## Pattern 1: RunInference for predictions\n",
        "This pattern demonstrates how to use RunInference for predictions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1099fe94-d4cf-422e-a0d3-0cfba8af64d5"
      },
      "source": [
        "### Use RunInference within the pipeline\n",
        "\n",
        "1. Create a PyTorch model handler object by passing required arguments such as `state_dict_path`, `model_class`, and `model_params` to the `PytorchModelHandlerTensor` class.\n",
        "2. Pass the `PytorchModelHandlerTensor` object to the RunInference transform to perform predictions on unkeyed data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 50,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "e488a821-3b70-4284-96f3-ddee4dcb9d71",
        "outputId": "add9af31-1cc6-496f-a6e4-3fb185c0de25"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "PredictionResult(example=tensor([20.]), inference=tensor([102.0095], grad_fn=<UnbindBackward0>))\n",
            "PredictionResult(example=tensor([40.]), inference=tensor([201.2056], grad_fn=<UnbindBackward0>))\n",
            "PredictionResult(example=tensor([60.]), inference=tensor([300.4017], grad_fn=<UnbindBackward0>))\n",
            "PredictionResult(example=tensor([90.]), inference=tensor([449.1959], grad_fn=<UnbindBackward0>))\n"
          ]
        }
      ],
      "source": [
        "torch_five_times_model_handler = PytorchModelHandlerTensor(\n",
        "    state_dict_path=save_model_dir_multiply_five,\n",
        "    model_class=LinearRegression,\n",
        "    model_params={'input_dim': 1,\n",
        "                  'output_dim': 1}\n",
        "                  )\n",
        "pipeline = beam.Pipeline()\n",
        "\n",
        "with pipeline as p:\n",
        "      (\n",
        "      p \n",
        "      | \"ReadInputData\" >> beam.Create(value_to_predict)\n",
        "      | \"ConvertNumpyToTensor\" >> beam.Map(torch.Tensor)\n",
        "      | \"RunInferenceTorch\" >> RunInference(torch_five_times_model_handler)\n",
        "      | beam.Map(print)\n",
        "      )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9d95e69b-203f-4abb-9abb-360bdf4d769a"
      },
      "source": [
        "## Pattern 2: Postprocess RunInference results\n",
        "This pattern demonstrates how to postprocess the RunInference results.\n",
        "\n",
        "Add a `PredictionProcessor` to the pipeline after `RunInference`. `PredictionProcessor` processes the output of the `RunInference` transform."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 51,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "96f38a5a-4db0-4c39-8ce7-80d9f9911b48",
        "outputId": "b1d689a2-9336-40b2-a984-538bec888cc9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "input is 20.0 output is 102.00947570800781\n",
            "input is 40.0 output is 201.20559692382812\n",
            "input is 60.0 output is 300.4017028808594\n",
            "input is 90.0 output is 449.1958923339844\n"
          ]
        }
      ],
      "source": [
        "class PredictionProcessor(beam.DoFn):\n",
        "  \"\"\"\n",
        "  A processor to format the output of the RunInference transform.\n",
        "  \"\"\"\n",
        "  def process(\n",
        "      self,\n",
        "      element: PredictionResult):\n",
        "    input_value = element.example\n",
        "    output_value = element.inference\n",
        "    yield (f\"input is {input_value.item()} output is {output_value.item()}\")\n",
        "\n",
        "pipeline = beam.Pipeline()\n",
        "\n",
        "with pipeline as p:\n",
        "    (\n",
        "    p\n",
        "    | \"ReadInputData\" >> beam.Create(value_to_predict)\n",
        "    | \"ConvertNumpyToTensor\" >> beam.Map(torch.Tensor)\n",
        "    | \"RunInferenceTorch\" >> RunInference(torch_five_times_model_handler)\n",
        "    | \"PostProcessPredictions\" >> beam.ParDo(PredictionProcessor())\n",
        "    | beam.Map(print)\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2be80463-cf79-481c-9d6a-81e500f1707b"
      },
      "source": [
        "## Pattern 3: Attach a key\n",
        "\n",
        "This pattern demonstrates how attach a key to allow your model to handle keyed data."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "746b67a7-3562-467f-bea3-d8cd18c14927"
      },
      "source": [
        "### Modify the model handler and post processor\n",
        "\n",
        "Modify the pipeline to read from sources like CSV files and BigQuery.\n",
        "\n",
        "In this step, you take the following actions:\n",
        "\n",
        "* To handle keyed data, wrap the `PytorchModelHandlerTensor` object around `KeyedModelHandler`.\n",
        "* Add a map transform that converts a table row into `Tuple[str, float]`.\n",
        "* Add a map transform that converts `Tuple[str, float]` to `Tuple[str, torch.Tensor]`.\n",
        "* Modify the post-inference processor to output results with the key."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 52,
      "metadata": {
        "id": "90b263fc-97a5-43dc-8874-083d7e65e96d"
      },
      "outputs": [],
      "source": [
        "class PredictionWithKeyProcessor(beam.DoFn):\n",
        "    def __init__(self):\n",
        "        beam.DoFn.__init__(self)\n",
        "\n",
        "    def process(\n",
        "          self,\n",
        "          element: Tuple[str, PredictionResult]):\n",
        "        key = element[0]\n",
        "        input_value = element[1].example\n",
        "        output_value = element[1].inference\n",
        "        yield (f\"key: {key}, input: {input_value.item()} output: {output_value.item()}\" )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f22da313-5bf8-4334-865b-bbfafc374e63"
      },
      "source": [
        "### Create a source with attached key\n",
        "This section shows how to create either a BigQuery or a CSV source with an attached key."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c9b0fb49-d605-4f26-931a-57f42b0ad253"
      },
      "source": [
        "#### Use BigQuery as the source\n",
        "Follow these steps to use BigQuery as your source."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "45ce4330-7d33-4c53-8033-f4fa02383894"
      },
      "source": [
        "To install the Google Cloud BigQuery API, use `pip`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4eb859dd-ba54-45a1-916b-5bbe4dc3f16e"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade google-cloud-bigquery --quiet"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6e869347-dd49-40be-b1e5-749699dc0d83"
      },
      "source": [
        "Create a table in BigQuery using the following snippet, which has two columns. The first column holds the key and the second column holds the test value. To use BiqQuery, you need a Google Cloud account with the BigQuery API enabled."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 54,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7mgnryX-Zlfs",
        "outputId": "6e608e98-8369-45aa-c983-e62296202c52"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Updated property [core/project].\n"
          ]
        }
      ],
      "source": [
        "!gcloud config set project $project"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 55,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "a6a984cd-2e92-4c44-821b-9bf1dd52fb7d",
        "outputId": "a50ab0fd-4f4e-4493-b506-41d3f7f08966"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<google.cloud.bigquery.table._EmptyRowIterator at 0x7f5694206650>"
            ]
          },
          "execution_count": 55,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from google.cloud import bigquery\n",
        "\n",
        "client = bigquery.Client(project=project)\n",
        "\n",
        "# Make sure the dataset_id is unique in your project.\n",
        "dataset_id = '{project}.maths'.format(project=project)\n",
        "dataset = bigquery.Dataset(dataset_id)\n",
        "\n",
        "# Modify the location based on your project configuration.\n",
        "dataset.location = 'US'\n",
        "dataset = client.create_dataset(dataset, exists_ok=True)\n",
        "\n",
        "# Table name in the BigQuery dataset.\n",
        "table_name = 'maths_problems_1'\n",
        "\n",
        "query = \"\"\"\n",
        "    CREATE OR REPLACE TABLE\n",
        "      {project}.maths.{table} ( key STRING OPTIONS(description=\"A unique key for the maths problem\"),\n",
        "    value FLOAT64 OPTIONS(description=\"Our maths problem\" ) );\n",
        "    INSERT INTO maths.{table}\n",
        "    VALUES\n",
        "      (\"first_question\", 105.00),\n",
        "      (\"second_question\", 108.00),\n",
        "      (\"third_question\", 1000.00),\n",
        "      (\"fourth_question\", 1013.00)\n",
        "\"\"\".format(project=project, table=table_name)\n",
        "\n",
        "create_job = client.query(query)\n",
        "create_job.result()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "479c9319-3295-4288-971c-dd0f0adfdd1e"
      },
      "source": [
        "To read keyed data, use BigQuery as the pipeline source."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 56,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "34331897-23f5-4850-8974-67e522e956dc",
        "outputId": "9d2b0ba5-97a2-46bf-c9d3-e023afbd3122"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "key: third_question, input: 1000.0 output: 4962.61962890625\n",
            "key: second_question, input: 108.0 output: 538.472412109375\n",
            "key: first_question, input: 105.0 output: 523.5929565429688\n",
            "key: fourth_question, input: 1013.0 output: 5027.0966796875\n"
          ]
        }
      ],
      "source": [
        "pipeline_options = PipelineOptions().from_dictionary({'temp_location':f'gs://{bucket}/tmp',\n",
        "                                                      })\n",
        "pipeline = beam.Pipeline(options=pipeline_options)\n",
        "\n",
        "keyed_torch_five_times_model_handler = KeyedModelHandler(torch_five_times_model_handler)\n",
        "\n",
        "table_name = 'maths_problems_1'\n",
        "table_spec = f'{project}:maths.{table_name}'\n",
        "\n",
        "with pipeline as p:\n",
        "      (\n",
        "      p\n",
        "      | \"ReadFromBQ\" >> beam.io.ReadFromBigQuery(table=table_spec) \n",
        "      | \"PreprocessData\" >> beam.Map(lambda x: (x['key'], x['value']))\n",
        "      | \"ConvertNumpyToTensor\" >> beam.Map(lambda x: (x[0], torch.Tensor([x[1]])))\n",
        "      | \"RunInferenceTorch\" >> RunInference(keyed_torch_five_times_model_handler)\n",
        "      | \"PostProcessPredictions\" >> beam.ParDo(PredictionWithKeyProcessor())\n",
        "      | beam.Map(print)\n",
        "      )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "53ee7f24-5625-475a-b8cc-9c031591f304"
      },
      "source": [
        "#### Use a CSV file as the source\n",
        "Follow these steps to use a CSV file as your source."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "06bc4396-ee2e-4228-8548-f953b5020c4e"
      },
      "source": [
        "Create a CSV file with two columns: one named `key` that holds the keys, and a second named `value` that holds the test values."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 62,
      "metadata": {
        "id": "exAZjP7cYAFv"
      },
      "outputs": [],
      "source": [
        "# creates a CSV file with the values.\n",
        "csv_values = [(\"first_question\", 105.00),\n",
        "      (\"second_question\", 108.00),\n",
        "      (\"third_question\", 1000.00),\n",
        "      (\"fourth_question\", 1013.00)]\n",
        "input_csv_file = \"./maths_problem.csv\"\n",
        "\n",
        "with open(input_csv_file, 'w') as f:\n",
        "  writer = csv.writer(f)\n",
        "  writer.writerow(['key', 'value'])\n",
        "  for row in csv_values:\n",
        "    writer.writerow(row)\n",
        "\n",
        "assert os.path.exists(input_csv_file) == True"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 66,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9a054c2d-4d84-4b37-b067-1dda5347e776",
        "outputId": "2f2ea8b7-b425-48ae-e857-fe214c7eced2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "key: first_question, input: 105.0 output: 523.5929565429688\n",
            "key: second_question, input: 108.0 output: 538.472412109375\n",
            "key: third_question, input: 1000.0 output: 4962.61962890625\n",
            "key: fourth_question, input: 1013.0 output: 5027.0966796875\n"
          ]
        }
      ],
      "source": [
        "pipeline_options = PipelineOptions().from_dictionary({'temp_location':f'gs://{bucket}/tmp',\n",
        "                                                      })\n",
        "pipeline = beam.Pipeline(options=pipeline_options)\n",
        "\n",
        "keyed_torch_five_times_model_handler = KeyedModelHandler(torch_five_times_model_handler)\n",
        "\n",
        "with pipeline as p:\n",
        "  df = p | beam.dataframe.io.read_csv(input_csv_file)\n",
        "  pc = to_pcollection(df)\n",
        "  (pc\n",
        "    | \"ConvertNumpyToTensor\" >> beam.Map(lambda x: (x[0], torch.Tensor([x[1]])))\n",
        "    | \"RunInferenceTorch\" >> RunInference(keyed_torch_five_times_model_handler)\n",
        "    | \"PostProcessPredictions\" >> beam.ParDo(PredictionWithKeyProcessor())\n",
        "    | beam.Map(print)\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "742abfbb-545e-435b-8833-2410ce29d22c"
      },
      "source": [
        "## Pattern 4: Inference with multiple models in the same pipeline\n",
        "This pattern demonstrates how use inference with multiple models in the same pipeline.\n",
        "\n",
        "### Multiple models in parallel\n",
        "This section demonstrates how use inference with multiple models in parallel."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "570b2f27-3beb-4295-b926-595592289c06"
      },
      "source": [
        "Create a torch model handler for the 10 times model using `PytorchModelHandlerTensor`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 67,
      "metadata": {
        "id": "73607c45-afe1-4990-9a55-e687ed40302e"
      },
      "outputs": [],
      "source": [
        "torch_ten_times_model_handler = PytorchModelHandlerTensor(state_dict_path=save_model_dir_multiply_ten,\n",
        "                                        model_class=LinearRegression,\n",
        "                                        model_params={'input_dim': 1,\n",
        "                                                      'output_dim': 1}\n",
        "                                        )\n",
        "keyed_torch_ten_times_model_handler = KeyedModelHandler(torch_ten_times_model_handler)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "70ebed52-4ead-4cae-ac96-8cf206012ce1"
      },
      "source": [
        "In this section, the same data is run through two different models: the one that we use to multiply by 5 \n",
        "and a new model that learns to multiply by 10."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 68,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "629d070e-9902-42c9-a1e7-56c3d1864f13",
        "outputId": "0b4d7f3c-4696-422f-b031-ee5a03e90e03"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "key: third_question * 10, input: 1000.0 output: 9889.59765625\n",
            "key: second_question * 10, input: 108.0 output: 1075.4891357421875\n",
            "key: first_question * 10, input: 105.0 output: 1045.84521484375\n",
            "key: fourth_question * 10, input: 1013.0 output: 10018.0546875\n",
            "key: third_question * 5, input: 1000.0 output: 4962.61962890625\n",
            "key: second_question * 5, input: 108.0 output: 538.472412109375\n",
            "key: first_question * 5, input: 105.0 output: 523.5929565429688\n",
            "key: fourth_question * 5, input: 1013.0 output: 5027.0966796875\n"
          ]
        }
      ],
      "source": [
        "# Pipeline options: only `temp_location` is set, pointing at gs://<bucket>/tmp.\n",
        "pipeline_options = PipelineOptions().from_dictionary(\n",
        "    {'temp_location': f'gs://{bucket}/tmp'})\n",
        "\n",
        "pipeline = beam.Pipeline(options=pipeline_options)\n",
        "\n",
        "read_from_bq = beam.io.ReadFromBigQuery(table=table_spec)\n",
        "\n",
        "with pipeline as p:\n",
        "  # Apply the BigQuery read once, under an explicit label, and branch the\n",
        "  # resulting PCollection to both models. Applying the same unlabeled\n",
        "  # transform to `p` twice would give both reads the same label (Beam\n",
        "  # expects labels to be unique within a pipeline) and would add a second,\n",
        "  # redundant read of the same table.\n",
        "  bq_rows = p | \"ReadFromBigQuery\" >> read_from_bq\n",
        "\n",
        "  # Branch 1: tag each key with '* 5' and run the multiply-by-5 model.\n",
        "  multiply_five = (\n",
        "      bq_rows\n",
        "      | \"CreateMultiplyFiveTuple\" >> beam.Map(lambda x: ('{} {}'.format(x['key'], '* 5'), x['value']))\n",
        "      | \"ConvertNumpyToTensorFiveTuple\" >> beam.Map(lambda x: (x[0], torch.Tensor([x[1]])))\n",
        "      | \"RunInferenceTorchFiveTuple\" >> RunInference(keyed_torch_five_times_model_handler)\n",
        "  )\n",
        "\n",
        "  # Branch 2: tag each key with '* 10' and run the multiply-by-10 model.\n",
        "  multiply_ten = (\n",
        "      bq_rows\n",
        "      | \"CreateMultiplyTenTuple\" >> beam.Map(lambda x: ('{} {}'.format(x['key'], '* 10'), x['value']))\n",
        "      | \"ConvertNumpyToTensorTenTuple\" >> beam.Map(lambda x: (x[0], torch.Tensor([x[1]])))\n",
        "      | \"RunInferenceTorchTenTuple\" >> RunInference(keyed_torch_ten_times_model_handler)\n",
        "  )\n",
        "\n",
        "  # Merge both branches into one PCollection and format each keyed\n",
        "  # prediction with PredictionWithKeyProcessor before printing it.\n",
        "  inference_result = (\n",
        "      (multiply_five, multiply_ten)\n",
        "      | \"FlattenResults\" >> beam.Flatten()\n",
        "      | \"FormatResults\" >> beam.ParDo(PredictionWithKeyProcessor())\n",
        "  )\n",
        "  inference_result | \"PrintResults\" >> beam.Map(print)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e71e6706-5d8d-4322-9def-ac7fb20d4a50"
      },
      "source": [
        "### Multiple models in sequence\n",
        "This section demonstrates how to use inference with multiple models in sequence.\n",
        "\n",
        "In a sequential pattern, data is sent to one or more models in sequence, \n",
        "with the output from each model chaining to the next model.\n",
        "The following list demonstrates the sequence used in this section.\n",
        "\n",
        "1. Read the data from BigQuery.\n",
        "2. Map the data.\n",
        "3. Use RunInference with the multiply by 5 model.\n",
        "4. Process the results.\n",
        "5. Use RunInference with the multiply by 10 model.\n",
        "6. Process the results.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 69,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8db9d649-5549-4b58-a9ad-7b8592c2bcbf",
        "outputId": "328ba32b-40d4-445b-8b4e-5568258b8a26"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "key: original input is `third_question tensor([1000.])`, input: 4962.61962890625 output: 49045.37890625\n",
            "key: original input is `second_question tensor([108.])`, input: 538.472412109375 output: 5329.11083984375\n",
            "key: original input is `first_question tensor([105.])`, input: 523.5929565429688 output: 5182.08251953125\n",
            "key: original input is `fourth_question tensor([1013.])`, input: 5027.0966796875 output: 49682.49609375\n"
          ]
        }
      ],
      "source": [
        "def process_interim_inference(element):\n",
        "    \"\"\"Turn a keyed prediction into a keyed input for the next model.\n",
        "\n",
        "    Args:\n",
        "      element: A `(key, prediction_result)` pair. `prediction_result` must\n",
        "        expose `example` and `inference` attributes.\n",
        "\n",
        "    Returns:\n",
        "      A `(new_key, inference)` tuple, where `new_key` is a string that\n",
        "      embeds the original key and the original example.\n",
        "    \"\"\"\n",
        "    original_key, result = element\n",
        "    # Carry the original key and example forward inside the new key.\n",
        "    new_key = 'original input is `{} {}`'.format(original_key, result.example)\n",
        "    return new_key, result.inference\n",
        "\n",
        "\n",
        "# Pipeline options: only `temp_location` is set, pointing at gs://<bucket>/tmp.\n",
        "pipeline_options = PipelineOptions().from_dictionary(\n",
        "    {'temp_location':f'gs://{bucket}/tmp'})\n",
        "pipeline = beam.Pipeline(options=pipeline_options)\n",
        "\n",
        "with pipeline as p:\n",
        "  # Stage 1: read the rows, build (key, tensor) pairs, and run them through\n",
        "  # the first keyed model handler.\n",
        "  bq_rows = p | beam.io.ReadFromBigQuery(table=table_spec)\n",
        "  keyed_values = (\n",
        "      bq_rows\n",
        "      | \"CreateMultiplyFiveTuple\" >> beam.Map(lambda row: (row['key'], row['value'])))\n",
        "  keyed_tensors = (\n",
        "      keyed_values\n",
        "      | \"ConvertNumpyToTensorFiveTuple\" >> beam.Map(lambda kv: (kv[0], torch.Tensor([kv[1]]))))\n",
        "  multiply_five = (\n",
        "      keyed_tensors\n",
        "      | \"RunInferenceTorchFiveTuple\" >> RunInference(keyed_torch_five_times_model_handler))\n",
        "\n",
        "  # Stage 2: re-key each prediction so that the first model's inference\n",
        "  # becomes the example passed to the second keyed model handler.\n",
        "  rekeyed_inferences = (\n",
        "      multiply_five\n",
        "      | \"ExtractResult\" >> beam.Map(process_interim_inference))\n",
        "  chained_predictions = (\n",
        "      rekeyed_inferences\n",
        "      | \"RunInferenceTorchTenTuple\" >> RunInference(keyed_torch_ten_times_model_handler))\n",
        "\n",
        "  # Format every keyed prediction and print it.\n",
        "  inference_result = chained_predictions | beam.ParDo(PredictionWithKeyProcessor())\n",
        "  inference_result | beam.Map(print)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
